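# Split images and their YOLO-format labels into fixed-size tiles. Edge tiles smaller than the tile size are
# zero-padded and saved to a separate folder; each box is clipped to the tile and kept only if the fraction of
# its area lying inside the tile exceeds a threshold (a visible-area ratio, not a true IoU).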

import os
import cv2
import numpy as np

def _read_yolo_boxes(label_path, img_w, img_h):
    """Load a YOLO label file and convert each box to absolute pixel corners.

    Args:
        label_path: Path to a ``.txt`` file with one ``class cx cy w h`` row per
            box, coordinates normalized to the image size.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        A list of ``(class_token, x_min, y_min, x_max, y_max)`` tuples. The class
        token is kept as the original text so it is written back unchanged
        (e.g. ``"3"`` instead of ``"3.0"``). A missing file yields an empty list.
        Blank lines are ignored. Malformed lines and boxes with non-positive
        size are skipped with a warning rather than aborting the whole run.
    """
    if not os.path.exists(label_path):
        return []

    boxes = []
    with open(label_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing lines
            if len(fields) != 5:
                print(f"Warning: {label_path}:{line_no}: expected 5 fields, got {len(fields)}; skipping line.")
                continue
            try:
                float(fields[0])  # validate the class id, but keep its original text
                x_center, y_center, box_w, box_h = (float(v) for v in fields[1:])
            except ValueError:
                print(f"Warning: {label_path}:{line_no}: non-numeric field; skipping line.")
                continue
            if box_w <= 0 or box_h <= 0:
                # A zero-area box would divide by zero in the visibility ratio.
                print(f"Warning: {label_path}:{line_no}: non-positive box size; skipping line.")
                continue

            x_center_abs = x_center * img_w
            y_center_abs = y_center * img_h
            half_w = box_w * img_w / 2
            half_h = box_h * img_h / 2
            boxes.append((fields[0],
                          x_center_abs - half_w, y_center_abs - half_h,
                          x_center_abs + half_w, y_center_abs + half_h))
    return boxes


def _labels_for_tile(boxes, x0, y0, tile_w, tile_h, tile_size, min_visible_ratio):
    """Clip absolute boxes to one tile and return YOLO rows for that tile.

    The tile covers pixels ``[x0, x0 + tile_w) x [y0, y0 + tile_h)`` of the source
    image. Boxes are clipped to this real extent, not to ``tile_size``, so a box
    never extends into the zero padding of an edge tile. A box is kept only if
    ``clipped_area / box_area > min_visible_ratio``. Returned coordinates are
    normalized by ``tile_size`` because every saved tile, padded or not, is
    ``tile_size`` pixels on each side.
    """
    x1 = x0 + tile_w
    y1 = y0 + tile_h
    rows = []
    for class_token, x_min, y_min, x_max, y_max in boxes:
        inter_x_min = max(x0, x_min)
        inter_y_min = max(y0, y_min)
        inter_x_max = min(x1, x_max)
        inter_y_max = min(y1, y_max)
        inter_w = inter_x_max - inter_x_min
        inter_h = inter_y_max - inter_y_min
        if inter_w <= 0 or inter_h <= 0:
            continue  # box does not overlap this tile

        box_area = (x_max - x_min) * (y_max - y_min)
        if inter_w * inter_h / box_area > min_visible_ratio:
            new_x_center = (inter_x_min + inter_x_max) / 2 - x0
            new_y_center = (inter_y_min + inter_y_max) / 2 - y0
            rows.append([class_token,
                         new_x_center / tile_size, new_y_center / tile_size,
                         inter_w / tile_size, inter_h / tile_size])
    return rows


def _write_yolo_labels(label_path, rows):
    """Write YOLO rows to ``label_path``. An empty file marks a background tile."""
    with open(label_path, 'w', encoding='utf-8') as f:
        f.writelines(" ".join(map(str, row)) + "\n" for row in rows)


def split_image_and_label(image_folder, label_folder, output_image_folder, output_label_folder, output_padded_folder,
                          tile_size=512, iou_threshold=0.3, image_extensions=('.jpg', '.jpeg')):
    """Split every image in ``image_folder`` into square tiles with matching YOLO labels.

    Full ``tile_size`` x ``tile_size`` tiles go to ``output_image_folder`` /
    ``output_label_folder``. Edge tiles that are smaller are zero-padded on the
    bottom/right to ``tile_size`` and saved under ``output_padded_folder/images``
    and ``output_padded_folder/labels``. Tiles are named
    ``{base_name}_{row_offset}_{col_offset}.jpg`` / ``.txt``.

    Args:
        image_folder: Directory containing the source images.
        label_folder: Directory containing ``<image base name>.txt`` YOLO labels.
            Images without a label file produce tiles with empty label files.
        output_image_folder: Destination for full-size tile images.
        output_label_folder: Destination for full-size tile labels.
        output_padded_folder: Root destination for padded edge tiles.
        tile_size: Side length of each tile in pixels. Must be positive.
        iou_threshold: Minimum fraction of a box's area that must lie inside a
            tile for the clipped box to be kept. Despite the name, this is a
            visible-area ratio, not an intersection-over-union.
        image_extensions: File extensions (case-insensitive) treated as images.

    Raises:
        ValueError: If ``tile_size`` is not positive.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if isinstance(image_extensions, str):
        image_extensions = (image_extensions,)
    extensions = tuple(ext.lower() for ext in image_extensions)

    os.makedirs(output_image_folder, exist_ok=True)
    os.makedirs(output_label_folder, exist_ok=True)
    padded_image_folder = os.path.join(output_padded_folder, "images")
    padded_label_folder = os.path.join(output_padded_folder, "labels")
    os.makedirs(padded_image_folder, exist_ok=True)
    os.makedirs(padded_label_folder, exist_ok=True)

    # Sorted for a deterministic processing order across runs / platforms.
    for image_name in sorted(os.listdir(image_folder)):
        image_path = os.path.join(image_folder, image_name)
        if not image_name.lower().endswith(extensions) or not os.path.isfile(image_path):
            continue

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals unreadable/corrupt files by returning None.
            print(f"Warning: could not read image {image_path}; skipping.")
            continue
        h, w = image.shape[:2]

        base_name = os.path.splitext(image_name)[0]
        label_path = os.path.join(label_folder, base_name + '.txt')
        boxes = _read_yolo_boxes(label_path, w, h)  # parsed once per image, not once per tile

        for i in range(0, h, tile_size):
            for j in range(0, w, tile_size):
                tile_image = image[i:i + tile_size, j:j + tile_size]
                tile_h, tile_w = tile_image.shape[:2]
                tile_labels = _labels_for_tile(boxes, j, i, tile_w, tile_h, tile_size, iou_threshold)

                if tile_h < tile_size or tile_w < tile_size:
                    padded_image = np.zeros((tile_size, tile_size) + tile_image.shape[2:], dtype=tile_image.dtype)
                    padded_image[:tile_h, :tile_w] = tile_image
                    tile_image = padded_image
                    image_dir, label_dir = padded_image_folder, padded_label_folder
                else:
                    image_dir, label_dir = output_image_folder, output_label_folder

                stem = f"{base_name}_{i}_{j}"
                save_image_path = os.path.join(image_dir, stem + ".jpg")
                if not cv2.imwrite(save_image_path, tile_image):
                    print(f"Warning: failed to write tile image {save_image_path}.")
                _write_yolo_labels(os.path.join(label_dir, stem + ".txt"), tile_labels)

# Folder paths (placeholders: replace with real directories before running).
image_folder = r"D:/***"
label_folder = r"D:/***"
output_image_folder = r"D:/***"
output_label_folder = r"D:/***"
output_padded_folder = r"D:/***"

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    split_image_and_label(image_folder, label_folder, output_image_folder, output_label_folder, output_padded_folder)
